package rest

import (
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"time"

	"github.com/zeromicro/go-zero/core/codec"
	"github.com/zeromicro/go-zero/core/load"
	"github.com/zeromicro/go-zero/core/logc"
	"github.com/zeromicro/go-zero/core/stat"
	"github.com/zeromicro/go-zero/rest/chain"
	"github.com/zeromicro/go-zero/rest/handler"
	"github.com/zeromicro/go-zero/rest/httpx"
	"github.com/zeromicro/go-zero/rest/internal"
	"github.com/zeromicro/go-zero/rest/internal/header"
	"github.com/zeromicro/go-zero/rest/internal/response"
)

// topCpuUsage is the CPU usage value that stands for 100% ("1000m" == 100%).
// newEngine uses it as the upper bound when deriving the priority shedder's
// CPU threshold.
const topCpuUsage = 1000

// ErrSignatureConfig is an error that indicates bad config for signature.
// signatureVerifier returns it when strict signature verification is enabled
// but no private keys are configured.
var ErrSignatureConfig = errors.New("bad config for Signature")

// engine holds the REST server configuration, the registered route groups and
// the pieces (middlewares, shedders, callbacks, TLS config) used to bind those
// routes to a router and start the HTTP(S) server.
type engine struct {
	// conf is the configuration passed to newEngine.
	conf RestConf
	// routes are the route groups registered through addRoutes.
	routes []featuredRoutes
	// timeout is the max timeout of all routes,
	// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
	// this network timeout is used to avoid DoS attacks by sending data slowly
	// or receiving data slowly with many connections to exhaust server resources.
	timeout time.Duration
	// unauthorizedCallback is passed to handler.Authorize for JWT-protected routes.
	unauthorizedCallback handler.UnauthorizedCallback
	// unsignedCallback, when non-nil, is passed to LimitContentSecurityHandler
	// for signature-protected routes.
	unsignedCallback handler.UnsignedCallback
	// chain, when non-nil, replaces the native middleware chain in bindRoute.
	// NOTE(review): not assigned in this file — presumably set by the server
	// options; confirm.
	chain chain.Chain
	// middlewares are user middlewares appended after auth/signature handlers.
	middlewares []Middleware
	// shedder and priorityShedder are created by newEngine only when
	// conf.CpuThreshold > 0; otherwise both stay nil.
	shedder         load.Shedder
	priorityShedder load.Shedder
	// tlsConfig, when non-nil, is applied to the HTTPS server in start.
	tlsConfig *tls.Config
}

// newEngine creates an engine for c. The engine-wide network timeout starts at
// c.Timeout milliseconds. When c.CpuThreshold is positive, a regular adaptive
// shedder is created at that threshold, plus a priority shedder whose
// threshold is the midpoint between c.CpuThreshold and topCpuUsage.
func newEngine(c RestConf) *engine {
	ng := &engine{conf: c}
	ng.timeout = time.Duration(c.Timeout) * time.Millisecond

	if c.CpuThreshold <= 0 {
		return ng
	}

	priorityThreshold := (c.CpuThreshold + topCpuUsage) >> 1
	ng.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
	ng.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(priorityThreshold))

	return ng
}

// addRoutes registers a route group with the engine. For SSE groups the
// handlers are wrapped by buildSSERoutes first. Afterwards the engine's
// network timeout is adjusted to account for the group's timeout.
func (ng *engine) addRoutes(r featuredRoutes) {
	routes := r.routes
	if r.sse {
		routes = buildSSERoutes(routes)
	}
	r.routes = routes

	ng.routes = append(ng.routes, r)
	ng.mightUpdateTimeout(r)
}

// appendAuthHandler appends JWT authorization to chn when the group enables
// JWT (including the previous secret when one is configured), then passes the
// resulting chain through verifier, which may add signature verification.
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
	verifier func(chain.Chain) chain.Chain) chain.Chain {
	if !fr.jwt.enabled {
		return verifier(chn)
	}

	if len(fr.jwt.prevSecret) > 0 {
		// accept tokens signed with either the current or the previous secret
		chn = chn.Append(handler.Authorize(fr.jwt.secret,
			handler.WithPrevSecret(fr.jwt.prevSecret),
			handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
	} else {
		chn = chn.Append(handler.Authorize(fr.jwt.secret,
			handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
	}

	return verifier(chn)
}

// bindFeaturedRoutes binds every route of group fr to router. The group's
// signature verifier is built once and shared by all its routes. The first
// error, from either the verifier or a binding, is returned.
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
	verify, err := ng.signatureVerifier(fr.signature)
	if err != nil {
		return err
	}

	for i := range fr.routes {
		if err = ng.bindRoute(fr, router, metrics, fr.routes[i], verify); err != nil {
			return err
		}
	}

	return nil
}

// bindRoute registers a single route on router. The handler is wrapped by:
// the engine's custom chain if set (otherwise the native middlewares), then
// JWT/signature handlers, then user middlewares in registration order.
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
	route Route, verifier func(chain.Chain) chain.Chain) error {
	var base chain.Chain
	if ng.chain != nil {
		base = ng.chain
	} else {
		base = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
	}

	full := ng.appendAuthHandler(fr, base, verifier)
	for i := range ng.middlewares {
		full = full.Append(convertMiddleware(ng.middlewares[i]))
	}

	return router.Handle(route.Method, route.Path, full.ThenFunc(route.Handler))
}

// bindRoutes binds all registered route groups to router, sharing a single
// metrics instance among them. It stops at the first error.
func (ng *engine) bindRoutes(router httpx.Router) error {
	metrics := ng.createMetrics()
	for i := range ng.routes {
		if err := ng.bindFeaturedRoutes(router, ng.routes[i], metrics); err != nil {
			return err
		}
	}
	return nil
}

// buildChainWithNativeMiddlewares builds the built-in middleware chain for a
// route, including only the middlewares enabled in conf.Middlewares. Per-group
// settings (priority shedding, timeout, max bytes) come from fr, and
// per-route labels (path, method) from route.
func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
	metrics *stat.Metrics) chain.Chain {
	type ctor = func(http.Handler) http.Handler

	toggles := ng.conf.Middlewares
	// Candidates are appended to the chain in exactly this order. Each
	// constructor runs only when its middleware is enabled, so disabled
	// middlewares are never instantiated.
	candidates := []struct {
		enabled bool
		build   func() ctor
	}{
		{toggles.Trace, func() ctor {
			return handler.TraceHandler(ng.conf.Name, route.Path,
				handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths))
		}},
		{toggles.Log, ng.getLogHandler},
		{toggles.Prometheus, func() ctor {
			return handler.PrometheusHandler(route.Path, route.Method)
		}},
		{toggles.MaxConns, func() ctor {
			return handler.MaxConnsHandler(ng.conf.MaxConns)
		}},
		{toggles.Breaker, func() ctor {
			return handler.BreakerHandler(route.Method, route.Path, metrics)
		}},
		{toggles.Shedding, func() ctor {
			return handler.SheddingHandler(ng.getShedder(fr.priority), metrics)
		}},
		{toggles.Timeout, func() ctor {
			return handler.TimeoutHandler(ng.checkedTimeout(fr.timeout))
		}},
		{toggles.Recover, func() ctor {
			return handler.RecoverHandler
		}},
		{toggles.Metrics, func() ctor {
			return handler.MetricHandler(metrics)
		}},
		{toggles.MaxBytes, func() ctor {
			return handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes))
		}},
		{toggles.Gunzip, func() ctor {
			return handler.GunzipHandler
		}},
	}

	chn := chain.New()
	for _, c := range candidates {
		if c.enabled {
			chn = chn.Append(c.build())
		}
	}

	return chn
}

// checkedMaxBytes returns bytes when it is positive, otherwise the global
// conf.MaxBytes limit.
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
	if bytes <= 0 {
		return ng.conf.MaxBytes
	}
	return bytes
}

// checkedTimeout returns the group's explicit timeout when one is set
// (including an explicit zero), otherwise the global conf.Timeout.
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
	if timeout == nil {
		// conf.Timeout is in milliseconds; ng.timeout is not used here because
		// mightUpdateTimeout may have raised it to the max of all groups.
		return time.Duration(ng.conf.Timeout) * time.Millisecond
	}
	return *timeout
}

// createMetrics creates the metrics instance shared by all routes. It is
// named after conf.Name, or "host:port" when no name is configured.
func (ng *engine) createMetrics() *stat.Metrics {
	name := ng.conf.Name
	if len(name) == 0 {
		name = fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)
	}
	return stat.NewMetrics(name)
}

// getLogHandler returns the detailed log middleware in verbose mode and the
// regular log middleware otherwise.
func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
	logHandler := handler.LogHandler
	if ng.conf.Verbose {
		logHandler = handler.DetailedLogHandler
	}
	return logHandler
}

// getShedder picks the shedder for a route group: the priority shedder for
// priority groups when it exists, the regular shedder otherwise (which may
// be nil when CPU-based shedding is not configured).
func (ng *engine) getShedder(priority bool) load.Shedder {
	if !priority || ng.priorityShedder == nil {
		return ng.shedder
	}
	return ng.priorityShedder
}

// hasTimeout reports whether the timeout middleware is enabled and the
// engine-wide network timeout is positive.
func (ng *engine) hasTimeout() bool {
	if !ng.conf.Middlewares.Timeout {
		return false
	}
	return ng.timeout > 0
}

// mightUpdateTimeout folds the timeout of route group r into the engine-wide
// network timeout, which must be the maximum over all groups so that
// http.Server.ReadTimeout/WriteTimeout never cut a slower route short.
//
// A zero engine timeout means read/write timeouts are disabled and stays
// zero. A group without its own timeout changes nothing. A group with an
// explicit zero timeout disables the network timeout for the whole engine.
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
	switch {
	case ng.timeout == 0, r.timeout == nil:
		// disabled already, or the group inherits the global timeout
		return
	case *r.timeout == 0:
		// an unlimited route makes any server-level read/write timeout unsafe
		ng.timeout = 0
	case *r.timeout > ng.timeout:
		ng.timeout = *r.timeout
	}
}

// notFoundHandler returns a middleware that handles 404 not found requests.
//
// Unmatched requests pass through a tracing middleware (plus the log
// middleware when conf.Middlewares.Log is enabled) and then reach next, or
// http.NotFoundHandler() when next is nil. Afterwards WriteHeader(404) is
// issued on a response.HeaderOnceResponseWriter.
// NOTE(review): that writer presumably ignores any WriteHeader after the
// first, so a status already written by next wins — confirm.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
	// The chain depends only on the engine config and next, so it is built
	// once here instead of being re-created (tracer, ignore-path set, log
	// handler) on every unmatched request.
	chn := chain.New(
		handler.TraceHandler(ng.conf.Name,
			"",
			handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
	)
	if ng.conf.Middlewares.Log {
		chn = chn.Append(ng.getLogHandler())
	}

	if next == nil {
		next = http.NotFoundHandler()
	}
	h := chn.Then(next)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cw := response.NewHeaderOnceResponseWriter(w)
		h.ServeHTTP(cw, r)
		cw.WriteHeader(http.StatusNotFound)
	})
}

// print writes every registered route as "METHOD PATH" to stdout, sorted
// lexically, under a "Routes:" header.
func (ng *engine) print() {
	var lines []string
	for _, group := range ng.routes {
		for _, r := range group.routes {
			lines = append(lines, fmt.Sprintf("%s %s", r.Method, r.Path))
		}
	}
	sort.Strings(lines)

	fmt.Println("Routes:")
	for i := range lines {
		fmt.Printf("  %s\n", lines[i])
	}
}

// setTlsConfig stores the TLS config that start applies to the HTTPS server
// when it is non-nil.
func (ng *engine) setTlsConfig(config *tls.Config) {
	ng.tlsConfig = config
}

// setUnauthorizedCallback stores the callback handed to the JWT authorization
// middleware of JWT-enabled route groups.
func (ng *engine) setUnauthorizedCallback(cb handler.UnauthorizedCallback) {
	ng.unauthorizedCallback = cb
}

// setUnsignedCallback stores the callback handed to the content-security
// middleware of signature-enabled route groups.
func (ng *engine) setUnsignedCallback(cb handler.UnsignedCallback) {
	ng.unsignedCallback = cb
}

// signatureVerifier returns a function that appends signature verification
// to a chain according to signature.
//
// When signing is disabled, or enabled without private keys in non-strict
// mode, the returned function leaves the chain untouched. Strict mode without
// keys yields ErrSignatureConfig. Otherwise each configured key file is loaded
// as an RSA decrypter keyed by its fingerprint; a load failure is returned as
// is.
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
	passThrough := func(chn chain.Chain) chain.Chain {
		return chn
	}

	if !signature.enabled {
		return passThrough, nil
	}

	if len(signature.PrivateKeys) == 0 {
		// nothing to verify with: strict mode treats this as misconfiguration
		if signature.Strict {
			return nil, ErrSignatureConfig
		}
		return passThrough, nil
	}

	decrypters := make(map[string]codec.RsaDecrypter, len(signature.PrivateKeys))
	for _, key := range signature.PrivateKeys {
		decrypter, err := codec.NewRsaDecrypter(key.KeyFile)
		if err != nil {
			return nil, err
		}
		decrypters[key.Fingerprint] = decrypter
	}

	return func(chn chain.Chain) chain.Chain {
		// the callback is read when the chain is built, not when the
		// verifier is created
		if ng.unsignedCallback != nil {
			return chn.Append(handler.LimitContentSecurityHandler(ng.conf.MaxBytes,
				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
		}

		return chn.Append(handler.LimitContentSecurityHandler(ng.conf.MaxBytes,
			decrypters, signature.Expiry, signature.Strict))
	}, nil
}

// start binds all routes to router and then serves it. Plain HTTP is used
// when neither CertFile nor KeyFile is configured; HTTPS is used otherwise.
// Engine defaults (TLS config for HTTPS, then network timeouts) are applied
// before opts, so caller-supplied options override them.
func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
	if err := ng.bindRoutes(router); err != nil {
		return err
	}

	useTLS := len(ng.conf.CertFile) > 0 || len(ng.conf.KeyFile) > 0

	var defaults []StartOption
	if useTLS {
		defaults = append(defaults, func(svr *http.Server) {
			if ng.tlsConfig != nil {
				svr.TLSConfig = ng.tlsConfig
			}
		})
	}
	defaults = append(defaults, ng.withNetworkTimeout())
	all := append(defaults, opts...)

	if !useTLS {
		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, all...)
	}

	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
		ng.conf.KeyFile, router, all...)
}

// use registers a user middleware; bindRoute appends user middlewares after
// the auth/signature handlers, in registration order.
func (ng *engine) use(m Middleware) {
	ng.middlewares = append(ng.middlewares, m)
}

// withNetworkTimeout returns a start option that, when hasTimeout reports
// true at server setup time, sets ReadTimeout to 0.8x and WriteTimeout to
// 1.1x of the engine-wide timeout. Otherwise the server is left unchanged.
func (ng *engine) withNetworkTimeout() internal.StartOption {
	return func(svr *http.Server) {
		if ng.hasTimeout() {
			limit := ng.timeout
			// 0.8x: if a client declares a longer Content-Length than it
			// sends, the read times out first; otherwise the handler would
			// time out with 503 Service Unavailable and trip the breaker.
			svr.ReadTimeout = limit * 4 / 5
			// 1.1x: leave the server enough time to finish writing; a factor
			// below 1.0 could keep clients from receiving responses.
			svr.WriteTimeout = limit * 11 / 10
		}
	}
}

// buildSSERoutes returns a copy of routes whose handlers are prepared for
// Server-Sent Events. Before delegating to the original handler, each wrapped
// handler:
//   - tries to clear the write deadline, logging at debug level if the
//     ResponseWriter does not support it;
//   - sets the Content-Type, Cache-Control and Connection headers for an
//     event stream.
//
// The input slice is not modified. Wrapping it in place would leak the SSE
// handlers into any other registration sharing the caller's backing array,
// and would double-wrap them if the same slice were registered twice.
func buildSSERoutes(routes []Route) []Route {
	sseRoutes := make([]Route, len(routes))
	for i, route := range routes {
		h := route.Handler
		route.Handler = func(w http.ResponseWriter, r *http.Request) {
			// remove the default write deadline set by http.Server,
			// because SSE requires the connection to be kept alive indefinitely.
			rc := http.NewResponseController(w)
			if err := rc.SetWriteDeadline(time.Time{}); err != nil {
				// Some ResponseWriter implementations (like timeoutWriter) don't support SetWriteDeadline.
				// This is expected behavior and doesn't affect SSE functionality.
				logc.Debugf(r.Context(), "unable to clear write deadline for SSE connection: %v", err)
			}

			w.Header().Set(header.ContentType, header.ContentTypeEventStream)
			w.Header().Set(header.CacheControl, header.CacheControlNoCache)
			w.Header().Set(header.Connection, header.ConnectionKeepAlive)
			h(w, r)
		}
		sseRoutes[i] = route
	}

	return sseRoutes
}

// convertMiddleware adapts a HandlerFunc-based Middleware to the
// http.Handler-based form used by chain.
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
	adapt := func(next http.Handler) http.Handler {
		wrapped := ware(http.HandlerFunc(next.ServeHTTP))
		return wrapped
	}
	return adapt
}
